{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Lecture 5: Convolutional networks\n",
    "\n",
    "Notebook adapted from [Deep Learning (with PyTorch)](https://github.com/Atcold/pytorch-Deep-Learning) by Alfredo Canziani. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'cpu'"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "device"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf = transforms.Compose([transforms.ToTensor(),\n",
    "                         transforms.Normalize((0.1307,), (0.3081,))])\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(datasets.MNIST(\"./data\", train=True, transform=tf),\n",
    "                                           batch_size=64, shuffle=True)\n",
    "\n",
    "test_loader = torch.utils.data.DataLoader(datasets.MNIST(\"./data\", train=False, transform=tf),\n",
    "                                          batch_size=1000, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAA9EAAADOCAYAAAA5WIGgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy81sbWrAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAk3ElEQVR4nO3de3xU9Z3/8c8kJAPBMAiUhMjFcFdxUZBblIDtkpZWXaAXVlvR1gsgWJDdUim7C9gWLPyKtA1oQUVtS6HtglqLSlYwIBBFDJWCoChEFEJWhSTcAknO/mENP/r9HDxnciYzc87r+XjMH3nPl5PvGT7fnPnOJJ8JWZZlCQAAAAAA+Fwp8Z4AAAAAAADJgk00AAAAAAAOsYkGAAAAAMAhNtEAAAAAADjEJhoAAAAAAIfYRAMAAAAA4BCbaAAAAAAAHGITDQAAAACAQ2yiAQAAAABwiE00AAAAAAAONYvVgZcsWSILFiyQw4cPyxVXXCGLFi2SoUOHfu6/q6+vl0OHDklmZqaEQqFYTQ+wZVmWVFdXS05OjqSkRPc6U7T1L8IaQHx5Uf8iXAOQvLgGIMiofwSZq/q3YmDlypVWWlqatWzZMmv37t3WlClTrJYtW1plZWWf+28PHjxoiQg3bnG/HTx4sMnrnzXALVFu0dZ/Y9cA9c8tUW5cA7gF+Ub9cwvyzUn9hyzLssRjgwYNkn79+snDDz/ckF122WUyatQomTdv3gX/bWVlpbRu3Vquk69KM0nzemrA56qVs/KKrJVjx45JJBJx/e8bU/8irAHEV2PrX4RrAJIb1wAEGfWPIHNT/57/OveZM2dk+/btcv/995+XFxQUyJYtW4zxNTU1UlNT0/B1dXX13yeWJs1CLB7Ewd9fVorm14jc1r8IawAJphH1L8I1AD7ANQBBRv0jyFzUv+eNxT766COpq6uTrKys8/KsrCwpLy83xs+bN08ikUjDrVOnTl5PCWgybutfhDUAf+EagCDjGoAgo/4RJDHrzv2PO3jLstRd/YwZM6SysrLhdvDgwVhNCWgyTutfhDUAf+IagCDjGoAgo/4RBJ7/One7du0kNTXVeMWpoqLCeGVKRCQcDks4HPZ6GkBcuK1/EdYA/IVrAIKMawCCjPpHkHj+TnR6err0799fioqKzsuLiookLy/P628HJBTqH0HHGkCQUf8IMuofQRKTz4meNm2a3HrrrXLNNdfIkCFDZOnSpfL+++/LhAkTYvHtgIRC/SPoWAMIMuofQUb9IyhisokeO3asfPzxx/LAAw/I4cOHpU+fPrJ27Vrp0qVLLL4dkFCofwQdawBBRv0jyKh/BEVMPie6MaqqqiQSichw+Rda2yMuaq2z8rI8I5WVldKqVasm//6sAcQT9Y+gYw0gyKh/BJmb+o/JO9EAkExS27ZR83/d/KaRVde3UMeumP01Nb/oDyXRTwwAAAAJJ2YfcQUAAAAAgN+wiQYAAAAAwCE20QAAAAAAOMQmGgAAAAAAh9hEAwAAAADgEN25AQTG6RsHqvnF/16m5m+f7mBkv9s6RB3be/cxNa93NjUAAABXXjy0o9HHGDppvJpnrHm10cf2M96JBgAAAADAITbRAAAAAAA4xCYaAAAAAACH2EQDAAAAAOAQm2gAAAAAAByiOzcA36n9Yn81f3rJIjVvldJczQf8eJKR9XxkqzqWLtxA4zS7JEfNT/zTJWoefn5bLKcDAAlvXFm+kT3VZaOrY2xa/Gs175Y/wci631fi6th+xjvRAAAAAAA4xCYaAAAAAACH2EQDAAAAAOAQm2gAAAAAABxiEw0AAAAAgEN0576A1Mt7qnlNdqaal91Z1+jvOazrPjXv16pMze+OHDCyFAmpY/N3fkPND7/zBTXPXVNrZM2On1HHyms79RyIsVAz88dYs5lH1LF2XbhLavRjhyutqOcFwJ7Wibvtn46rY1d3XqLmN10ywNM5IXaa5XYxshHP/VUdO6n1
u66O3eulu4zMOpWqju34ov7e0UUv6M9hQulpRlZ3rNLF7BAkLx7aoeZaB+398y9Tx2asedXV9zwypMr8flvN7yfivmv3u2MfMbJuYnbsFglm127eiQYAAAAAwCE20QAAAAAAOMQmGgAAAAAAh9hEAwAAAADgUOAai536l4Fq3nra+0b2o06/V8deHa5X8xSb1yTqxRzvZqz78frY9Veu0o99pc2xx5jHPlKnd2D60u9/oOZdf7hVzQGvvP1QfyN7p/fDro4x59vfVfNWW4PXKANoCvtvv9TIVnf+RdNPBE1i9w+zjGx163fUsfqzIHtvfenXRmb7nOkG/ejffm+kmve8qMLInt1/pTq2w0Ppap7+ziE1ry3XG2Ai8WVtbeVqvNbQa6jojcW8oDUbExHp9pDeFExrIGbHbuzQjePV3G2jtGTCO9EAAAAAADjEJhoAAAAAAIfYRAMAAAAA4BCbaAAAAAAAHGITDQAAAACAQ4Hrzn2yXaqaT+mwxcgGhEPq2Hqb1x5SRB+vvVYxq+Jqm7Hu/PlAHyO78dK/eXLsCW3Nx+SS1Ax17O7vFKr59X2/qeatvvWRkdVV6d0EESypPbup+Xs/0WvvrWvN2qu1OfbAbePUPPtVb9YMgiG1e66af3Rdtpq3W73LyIL+8+6ivP81MruOykh+WZ0/afQx1hxvr+ajlQ7abv2u6/OOx85qv13NU1bq9bvwk95q/ujOa40sI0P/BJSOE/THjw7f8aF1276QcWX5RhaPrtXd79M/cWTcYHN+Iu7Oc9Nis0u+iMhQ8W/Xbq5YAAAAAAA4xCYaAAAAAACH2EQDAAAAAOAQm2gAAAAAABxiEw0AAAAAgEOB687d9rGtav74i0ON7CejuujH2HW60fNI3fBGo48hIpIju41su0evjdw9YKKRVXVrqY7tNdXsPisisuHKP6r58D+YXbsv+kqwu9UGTUqG3m37W89uUvNbM8ttjmR23P/qnpvUkdmj3nI0N+BCJjz/opqPzKhW8yFpk42s7TL9WhRk9VKv5htOXdTEM4HXNvVdZWT6/7bINb+Youad/3RIzR/v0NrILJsPSzmVHVbzozcfV3OrNKIfSNHslJ7X5VWqeZsXWhhZ5F2bY3x81PE84J19Dw22uWeHmmpduEVEjgxJ7Oe3tvPTl5wrudP1511H1jT+2PHGO9EAAAAAADjEJhoAAAAAAIfYRAMAAAAA4BCbaAAAAAAAHGITDQAAAACAQ4Hrzm2n9oMPjax9oZkFibVtp5G1fkfvVPnO7V9Q85ROeovMn/c0u3bPkv4uZodkF0pPU3P7Lty6PltuM7Ku/35MHVvr6siAN9LGVJjhsqafRzw0u7SzmhdevsLxMSa+NE7Ne8q2qOaE2PngR3k292w3kldr9GvAJUV6N+va9w6oeeg9JbOZhf75IiIt/2RzhwdSH9OfN9Ud089TY3k1Gdg6OXqQkb079hFXx9hccrmad5eSqOYUb91WTTAyt4/JU102qvnQ0eONLGPNq66OHW+8Ew0AAAAAgENsogEAAAAAcIhNNAAAAAAADrGJBgAAAADAIdeNxTZu3CgLFiyQ7du3y+HDh2XNmjUyatSohvsty5I5c+bI0qVL5ejRozJo0CBZvHixXHHFFV7OG01h4JVGNHK53iDg7tb/o+b1Nq/TfPfJe42ss2xxMbn4oP698+ETOa7GH7dq1DzyjNkqprbMbIqHxgta/R+7dYiaD2ux1eZfpKvpyXVZRtZK3o12WknFCuuPSV89VvV69LR+7Ggm1EhBWwNudfzS+2qeFko1so3He6tjrdJdns4p3tw0EEt0fq7/Q/l27eiCq/t9ZkO0bmI2GxNx33Bs0+JfG9mX11zl6hjx5vqd6BMnTkjfvn2lsLBQvX/+/PmycOFCKSwslG3btkl2draMGDFCqqurGz1ZIN6ofwQZ9Y+gYw0gyKh/4BzX70SPHDlSRo4cqd5nWZYsWrRIZs6cKWPGjBERkSeffFKysrJk
xYoVMn682c4cSCbUP4KM+kfQsQYQZNQ/cI6nfxO9f/9+KS8vl4KCgoYsHA7LsGHDZMsW/Vd1a2pqpKqq6rwbkIyiqX8R1gD8gfpH0LEGEGTUP4LG0010eXm5iIhkZZ3/919ZWVkN9/2jefPmSSQSabh16tTJyykBTSaa+hdhDcAfqH8EHWsAQUb9I2hi0p07FDr/j/MtyzKyz8yYMUMqKysbbgcPHozFlIAm46b+RVgD8BfqH0HHGkCQUf8ICtd/E30h2dnZIvLpq1EdOnRoyCsqKoxXpj4TDoclHA57OY3AS728p5rvmXixmv/bl9aq+d2RJ4wsRfQfhHZduC//3WQ17zon8TtxuxVN/YsEYw2c/ef+RvZCv1/YjM5Q06EP/Zuad/it/2opGfmx/mv1UpSMkN5a2u7nY9rxePSRTgz7Zpvd80VEUpRrxpG6U+rY1I/1pkS10U8rJvy4Btyacan+fOKsVWdkFzc7oY6tH6p3xbezX2kWnLa3hTr2dMezat79N86rKfW13Wpu1eifIBEUyV7/1w7W/1/d0LpZ+43tOY5t2nkkAk/fic7NzZXs7GwpKipqyM6cOSPFxcWSl5fn5bcCEg71jyCj/hF0rAEEGfWPoHH9TvTx48dl3759DV/v379fduzYIW3atJHOnTvL1KlTZe7cudKjRw/p0aOHzJ07VzIyMuSWW27xdOJAPFD/CDLqH0HHGkCQUf/AOa430a+//rpcf/31DV9PmzZNRERuu+02eeKJJ2T69Oly6tQpueeeexo+aH3dunWSmZnp3ayBOKH+EWTUP4KONYAgo/6Bc1xvoocPHy6WZf/3XaFQSGbPni2zZ89uzLyAhET9I8iofwQdawBBRv0D58SkOzcAAAAAAH7kaXduxM578/Vuld8csdnIbor8Xh17dbhezbUuqSIi9aKN18f2+tMkNe8xfauaI1gqp5jdddun6q2PH6vqqOYdn9ij5mbPV/f+d6K+vurTzG7Lrffp3V3Da7d5MBPES2rbNkbWfPQRdaz+s1Gk9Ix+7PbFFUbmRd0mg135j6u59hiOfvN76tg2773t6ZyQGO6MvKfnK5c2+tgpw908rxGRrzk/9h1lI9T8k7suVfO6XXudHxxx81SXjfGeQlIbV5av5m4e16ytrdT8yJCqqOYUa7wTDQAAAACAQ2yiAQAAAABwiE00AAAAAAAOsYkGAAAAAMAhNtEAAAAAADhEd+4E8+EP89R8z7cL1bxezM/rSxGzo/CnY/XXTOzGa6+x2I3d+43Faj5g371qnvPk34ysrioxu+/BucPT9PotuXqRkR239P7EK6d8Vc3TPn5dzSu/PdjIPrnhlDr2mSEPq3nPtDfUXHPcqlHzEX8dp+ZfuEOv69pyvfMz4uOtBV2NbM+Ver3YuWPJFDXPeXtLVHNKJqH+V9jcs71J54Hg+eXR3kb2x7J+6thNV61o9Pd7rEuRmr/ybHM1v+v5O9X8soXmNaD2vQNRzwvO7HvIfM7wqR1NOQ3f2VxyuX6Hi+7cdp28h44er+YZa151fOxY4J1oAAAAAAAcYhMNAAAAAIBDbKIBAAAAAHCITTQAAAAAAA7RWCzBtN1dq+ZaA7FP83oju+zlu119z7R3Wqj52R5mc6b6j9PVsW99XW98tu3+X6l5rx73GFmP78e3QQAa7+Q1J9U8HDJ/1HxSd0Ydm7ZObyD28V1D1Lzov35uZK1S9AYvInp+3+FBav6dtmZDqP7pYXXs1qtWqXmJTU+pn478lpHV7d2nD4ZnUrvnqnnxPy9SUv3/2k7OguA2ELtn5ZpGHzv9yYsbfQwkj1dr0tR84qPm8wMRkUt/d1DNrarjRtamer86dswlY9T8VK8sNS/7innt2j1Wf16T31y/pr01Wn9+NHfoVUb2+g36z6fagx+oORLLydH6c4l4N8BqCt3vK9HvGNv4Yx/K15sad2/8ZadReCcaAAAAAACH2EQDAAAAAOAQm2gAAAAAABxiEw0AAAAA
gENsogEAAAAAcIju3Amm+Z9fU/Mb/tzf8TG6SalX03Hspu8PUPMf79+m5nu/vsTILj89WR3bdfrW6CeGmAiF9a7Fq4YstfkX5o+aL27RO7DmyptqPvBuva61Tty9fzdJHXvpX06reermnWo+u9d3jOzw8Lbq2M0zFqn54LD+Y/aTh8wsckOqOlbq6/Qcrq3dqLfzPGvpn1KgmVVxtZq/vUz/OeiFtIv0zr+78h83Mrv5zWmvr6EU0Tuf6p8KsV2foA27Y39z3w1GdtEf/d/B1u9SQuYnhoiIbK0xf7b99JZx6tiONh9poH92iTu1ZXqH7zSbvPs6Mxu5doI6dv1Tj6n5Wf3DVeQ/2pnXur63D1PHdvox3bm9EsQu0vHUbZW+Xt4d+0gTz8R7vBMNAAAAAIBDbKIBAAAAAHCITTQAAAAAAA6xiQYAAAAAwCE20QAAAAAAOER3bsTU1Bn3qvmGn//KyL45YrM6djuv9SScE1+9Ss2vSnfeSb37A3qnbC/6UNd30I+dUqx3J7Zpnip1u/YaWftd+tivb75dzZ/7y2/VfHDWASPbk6J39rT0hreIwrqTaWo+rMVZx8eY1V7vUD3nq3p91Uvj/wNTbH4Oase+KaLPY8qha22Ora+AF1/qZ2S1bfQeyXu+Zn7iwmdH13zw265G1laO2BwDyWLavIlq3nbnCTMs0T+JIdGl/Y++/n9QrnfF/0mW/qkrmlMdnf8cgrfGleUb2VNdNsZhJkgW7E4AAAAAAHCITTQAAAAAAA6xiQYAAAAAwCE20QAAAAAAOMQmGgAAAAAAh+jOfQG1X+yv5mV36v2D607qD2fPO1/3bE7JJnNliZrP+aH52P64/Q517Jevv1PNUze8EfW80DgV/Tx4/e2s3uXXzt9++k9q/quflhvZmqEPq2NHLZ6i5j2n2nTtPnvG4exETme3VPPUkP5YPbfnSiPrVqvPA975f+NuUfOTT/63kY3MOOrJ9yytMWvg3t03uzrGRYURx2PDR06quVVq01reRlcxu+0ff8Hsqn0hR+pOqXnmQXfrH8mh7TLnn9CQSJpdkqPm7911qZH1uv5ddezcrLVqbtebX/u50PtX1a6OAe9sLrncDF1257528G4153MH/Il3ogEAAAAAcIhNNAAAAAAADrGJBgAAAADAITbRAAAAAAA4RGOxC7h07l41X9tpg5qPGTRKzQPdPmWg2TxJRGRCW7PxU720iPVskMRaPP2amq/bbTap+9Mv+6lj3xmlNxy78YGvqHndkQojCw3Qa3rOkmX6MSy9JUz6W9R7PIS2/FXNlw3PN7KlLb35PwrVmA3q2pS97cmxNVbMjixyddsPXY1/o6a9mqe/sM2L6QCqlMxMNd//gz5q/p9j/6Dm37joGTff1cVYkZtfmmBkPf/GuoiXnI3KT86x7o7xlE0jsnFbzevLkSFV7g6OhMM70QAAAAAAOMQmGgAAAAAAh9hEAwAAAADgEJtoAAAAAAAcYhMNAAAAAIBDdOf+uw9/mGdkazsVqmPrbV57qP3AXdfSIDjxwAk175Bqdr09UndKHdvs+Fk1j2UHWlxY13l6h+MD406q+aXNMoys58r31bElDw1W84t3HFNzrQ4yp6tD5Wtpt6r53lkt1bxNp9ZGtqXf4+rYOkuvyAkfXK/mucsPGFmgO/nHWe2Hh+I9haSQEtK7zafYXBevTDc73IuI1A81296mbCqNfmIIrA9+ZD5/e+zOX6ljrw6vj9k8vlv2JTUvW9BLzXs/b15H9dWFppCx5lUj65ZvdlAXEXl37COujq117R46erzjeSSDfTbP3dw+Vpru95U0+hixwDvRAAAAAAA4xCYaAAAAAACH2EQDAAAAAOAQm2gAAAAAABxiEw0AAAAAgEOuunPPmzdPVq9eLXv27JEWLVpIXl6e/OxnP5Nevc51HrQsS+bMmSNLly6Vo0ePyqBBg2Tx4sVyxRVXeD55L7XdbfbFrbfp/1xP/0TTwCvVeMOVT6i5
9hiO3H63OjZn286op+UlP9e/W/Un9S7cNzyut8Ve9d2FRvbz7Nf0g//MJk8YqWqaV/qvat7uxrdtjqN3rk9krAE8/9I1ar7gO1vUvGMz85MYRESO9m5uZG03RT+vpkD9Ry+1bRsjq+3dWR1bNfO4mm/qu8rm6NuNJC2k/5yurD+j5qU1+ic0fPfFO43ssh/tVcfWHTum5hmid1tOxmeSQVsDdl2hxw3OV3OtC7edTYt/rR97un7szSWXq7mbztUnRw9S80P5ITW/dvBuI7M/xx2O52HnyzlXNfoYTcnVO9HFxcUyadIkKSkpkaKiIqmtrZWCggI5ceLck8H58+fLwoULpbCwULZt2ybZ2dkyYsQIqa6u9nzyQFOi/hF0rAEEGfWPoGMNAOe4eif6hRdeOO/r5cuXS/v27WX79u2Sn58vlmXJokWLZObMmTJmzBgREXnyySclKytLVqxYIePHm5+JVlNTIzU1NQ1fV1VVRXMeQMzFov5FWANIHlwDEGRcAxB0XAOAcxr1N9GVlZUiItKmzae/prN//34pLy+XgoKChjHhcFiGDRsmW7bov+o1b948iUQiDbdOnTo1ZkpAk/Gi/kVYA0heXAMQZFwDEHRcAxBkUW+iLcuSadOmyXXXXSd9+vQREZHy8nIREcnKyjpvbFZWVsN9/2jGjBlSWVnZcDt48GC0UwKajFf1L8IaQHLiGoAg4xqAoOMagKBz9evc/7/JkyfLm2++Ka+88opxXyh0/h+oW5ZlZJ8Jh8MSDoejnQYQF17VvwhrAMmJawCCjGsAgo5rAIIuqk30vffeK88++6xs3LhROnbs2JBnZ2eLyKevRHXo0KEhr6ioMF6VSjTN/2x2BJ4wc5g69pFOxWr+4/3b1HzGHROMrNl6s5tkwlE6br9za4Y6dO+YJWpu1yGz5LSZ5czTxyYaP9a/VzrP0X9da/Jfv29kH4yqc3XsZUOfUPPhzc+6Oo6mok7vNn7t+ilG1vkP+i/wtHvhjUbPI1mwBoKr5y/L1PwrG+5xdZysneZxzM/ISEzUvz1rSF81L3jUfN40sfWLro5t18361Zo0I1v58WB17Bu/uErNI7/VOxz3FPO5obsrlz8FfQ0cGaL/3Xa3h8zn+yIi7459xPGxbbtf2+VjHR9avOig7da4Mr3b+P75lxmZXSf7ROXq17kty5LJkyfL6tWrZf369ZKbm3ve/bm5uZKdnS1FRUUN2ZkzZ6S4uFjy8vK8mTEQJ9Q/go41gCCj/hF0rAHgHFfvRE+aNElWrFghzzzzjGRmZjb8fUMkEpEWLVpIKBSSqVOnyty5c6VHjx7So0cPmTt3rmRkZMgtt9wSkxMAmgr1j6BjDSDIqH8EHWsAOMfVJvrhhx8WEZHhw4efly9fvlxuv/12ERGZPn26nDp1Su65556GD1lft26dZGZmejJhIF6ofwQdawBBRv0j6FgDwDmuNtGWZX3umFAoJLNnz5bZs2dHOycgIVH/CDrWAIKM+kfQsQaAc6Luzh0Eby3qo+b1P9+g5len639i/sCjy4zsL5VXqWP/fED/nuHnImre9rGtRtasU0dlpEjZLZ3VPGXIUTV/rt/DRtYhtYU6tt6m5YfWQExE5D/uuMvIUl8LTmOmoGnxtNmcpcfT7o4xX8xGd5/msdNDkqABINCEaj88pObpNrntcbyYDGKu9ov91fxQvt5NufiOBWp+cUpzI7NrFFZaoz+XuvklvWlTryVmY0irdJc6NiJ6AzHAC93v0+vry/ddZWT7HtKb37lpQuYVu+Zfms0ll6u53bmL6E3Ykq2JmCbqz4kGAAAAACBo2EQDAAAAAOAQm2gAAAAAABxiEw0AAAAAgENsogEAAAAAcIju3BeQuVLvNHfTe9/T8ydeVvO7IweMbHD7HerYOe1L1TxlYEjN6x8wP24gRfQu1/WifzRBitgcW8xO3EfqTqljl3ycp+bbx/dVczpxAwDQ9JpdkqPmb83oZGTjh61Xx05ts9vm
6OlqOqviaiN74dfXqmOzXtE/MaTn37ap+ed/6BKQeNx08o49vYO2pjsd7hvwTjQAAAAAAA6xiQYAAAAAwCE20QAAAAAAOMQmGgAAAAAAh9hEAwAAAADgEN25o/HaTjX+y+BcNX903I1G9r2Jf1HH3t16n8031V/vqJf6Ro61H5//5reMLPJfGepYa5v+mIjY5QAAoKl99MUuav7W6F86PkbfzfqnlHSdc0b/Bx+WG9EXjm1Vh9o9UwGARMI70QAAAAAAOMQmGgAAAAAAh9hEAwAAAADgEJtoAAAAAAAcYhMNAAAAAIBDdOf2UF1VlZq3L9xiZM8VXqyOfU4GeDqnxmgl7xqZFYd5AAAAb7T+jd4V+6bfOH/+0cXmkzfqopoRACQf3okGAAAAAMAhNtEAAAAAADjEJhoAAAAAAIfYRAMAAAAA4BCbaAAAAAAAHGITDQAAAACAQ2yiAQAAAABwiE00AAAAAAAOsYkGAAAAAMAhNtEAAAAAADjEJhoAAAAAAIfYRAMAAAAA4BCbaAAAAAAAHGITDQAAAACAQ2yiAQAAAABwqFm8J/CPLMsSEZFaOStixXkyCKRaOSsi52qxqbEGEE/UP4KONYAgo/4RZG7qP+E20dXV1SIi8oqsjfNMEHTV1dUSiUTi8n1FWAOIL+ofQccaQJBR/wgyJ/UfsuL1UpON+vp6OXTokGRmZkp1dbV06tRJDh48KK1atYr31GKmqqrK9+eZTOdoWZZUV1dLTk6OpKQ0/V88BG0NJFNtRCuZzpH6b1rJVBuNkUznyRpoWslUG9FKpnOk/ptWMtVGYyTLebqp/4R7JzolJUU6duwoIiKhUEhERFq1apXQD7hXgnCeyXKO8Xj19TNBXQOcY+Kg/pteEM5RJHnOkzXQ9DjHxEH9N70gnKNIcpyn0/qnsRgAAAAAAA6xiQYAAAAAwKGE3kSHw2GZNWuWhMPheE8lpoJwnkE4x1gIwuPGOcJOEB63IJyjSHDO02tBeNw4R9gJwuMWhHMU8ed5JlxjMQAAAAAAElVCvxMNAAAAAEAiYRMNAAAAAIBDbKIBAAAAAHCITTQAAAAAAA6xiQYAAAAAwKGE3kQvWbJEcnNzpXnz5tK/f3/ZtGlTvKcUtY0bN8qNN94oOTk5EgqF5Omnnz7vfsuyZPbs2ZKTkyMtWrSQ4cOHy65du+Iz2SjNmzdPBgwYIJmZmdK+fXsZNWqU7N2797wxfjjPpuKn+hfx/xqg/r3npzXg9/oXYQ14zU/1L+L/NUD9e4v6T77aCNoaSNhN9KpVq2Tq1Kkyc+ZMKS0tlaFDh8rIkSPl/fffj/fUonLixAnp27evFBYWqvfPnz9fFi5cKIWFhbJt2zbJzs6WESNGSHV1dRPPNHrFxcUyadIkKSkpkaKiIqmtrZWCggI5ceJEwxg/nGdT8Fv9i/h/DVD/3vLbGvB7/YuwBrzkt/oX8f8aoP69Q/0nZ20Ebg1YCWrgwIHWhAkTzst69+5t3X///XGakXdExFqzZk3D1/X19VZ2drb14IMPNmSnT5+2IpGI9cgjj8Rhht6oqKiwRMQqLi62LMu/5xkLfq5/ywrGGqD+G8fPayAI9W9ZrIHG8HP9W1Yw1gD1Hz3q3x+14fc1kJDvRJ85c0a2b98uBQUF5+UFBQWyZcuWOM0qdvbv3y/l5eXnnW84HJZhw4Yl9flWVlaKiEibNm1ExL/n6bWg1b+IP2uD+o9e0NaAX2uDNRCdoNW/iD9rg/qPDvX/KT/Uht/XQEJuoj/66COpq6uTrKys8/KsrCwpLy+P06xi57Nz8tP5WpYl06ZNk+uuu0769OkjIv48z1gIWv2L+K82qP/GCdoa8GNtsAaiF7T6F/FfbVD/0aP+z0nmcw7CGmgW7wlcSCgUOu9ry7KMzE/8dL6TJ0+WN998U1555RXjPj+dZywF8XHyyzlT/94I2mPlp/NlDTReEB8nv5wz9d94QXyc/HTOQVgDCflOdLt27SQ1NdV4VaKi
osJ49cIPsrOzRUR8c7733nuvPPvss7Jhwwbp2LFjQ+6384yVoNW/iL9qg/pvvKCtAb/VBmugcYJW/yL+qg3qv3Go/3OS9ZyDsgYSchOdnp4u/fv3l6KiovPyoqIiycvLi9OsYic3N1eys7PPO98zZ85IcXFxUp2vZVkyefJkWb16taxfv15yc3PPu98v5xlrQat/EX/UBvXvnaCtAb/UBmvAG0GrfxF/1Ab17w3q/1PJWBuBWwNN1sLMpZUrV1ppaWnWY489Zu3evduaOnWq1bJlS+vAgQPxnlpUqqurrdLSUqu0tNQSEWvhwoVWaWmpVVZWZlmWZT344INWJBKxVq9ebe3cudO6+eabrQ4dOlhVVVVxnrlzEydOtCKRiPXyyy9bhw8fbridPHmyYYwfzrMp+K3+Lcv/a4D695bf1oDf69+yWANe8lv9W5b/1wD17x3qPzlrI2hrIGE30ZZlWYsXL7a6dOlipaenW/369WtokZ6MNmzYYImIcbvtttssy/q07fusWbOs7OxsKxwOW/n5+dbOnTvjO2mXtPMTEWv58uUNY/xwnk3FT/VvWf5fA9S/9/y0Bvxe/5bFGvCan+rfsvy/Bqh/b1H/yVcbQVsDIcuyLG/e0wYAAAAAwN8S8m+iAQAAAABIRGyiAQAAAABwiE00AAAAAAAOsYkGAAAAAMAhNtEAAAAAADjEJhoAAAAAAIfYRAMAAAAA4BCbaAAAAAAAHGITDQAAAACAQ2yiAQAAAABwiE00AAAAAAAO/R8uOdFGPno8rgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 1200x400 with 5 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "batch = next(iter(train_loader))\n",
    "x = batch[0][:10]\n",
    "y = batch[1][:10]\n",
    "\n",
    "fig, axs = plt.subplots(1, 5, figsize=(12, 4))\n",
    "\n",
    "for i in range(5):\n",
    "    axs[i].imshow(x[i].squeeze().numpy())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MLP vs ConvNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLP(nn.Module):\n",
    "    def __init__(self, input_size, n_hidden, output_size):\n",
    "        super().__init__()\n",
    "        self.input_size = input_size\n",
    "        self.net = nn.Sequential(\n",
    "            nn.Linear(input_size, n_hidden),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(n_hidden, n_hidden),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(n_hidden, output_size),\n",
    "            # nn.Softmax(dim=-1)\n",
    "        )\n",
    "        \n",
    "    def forward(self, x):\n",
    "        x = x.view(-1, self.input_size)\n",
    "        return self.net(x)\n",
    "\n",
    "class ConvNet(nn.Module):\n",
    "    def __init__(self, input_size, n_kernels, output_size):\n",
    "        super().__init__()\n",
    "        self.net = nn.Sequential(\n",
    "            nn.Conv2d(in_channels=1, out_channels=n_kernels, kernel_size=5),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(kernel_size=2),\n",
    "            nn.Conv2d(in_channels=n_kernels, out_channels=n_kernels, kernel_size=5),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(kernel_size=2),\n",
    "            nn.Flatten(),\n",
    "            nn.Linear(n_kernels * 4 * 4, 50),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(50, 10),\n",
    "            # nn.Softmax(dim=-1)\n",
    "        )\n",
    "        \n",
    "    def forward(self, x):\n",
    "        return self.net(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, perm=torch.arange(0, 784).long(), n_epochs=1):\n",
    "    model.train()    \n",
    "    optimizer = torch.optim.AdamW(model.parameters())\n",
    "    \n",
    "    for epoch in range(n_epochs):\n",
    "        for i, (data, target) in enumerate(train_loader):\n",
    "            # send to device\n",
    "            data, targets = data.to(device), target.to(device)\n",
    "            print\n",
    "\n",
    "            # permute pixels\n",
    "            data = data.view(-1, 28*28)\n",
    "            data = data[:, perm]\n",
    "            data = data.view(-1, 1, 28, 28)\n",
    "\n",
    "            # step\n",
    "            optimizer.zero_grad()\n",
    "            logits = model(data)\n",
    "            print(logits.dtype, logits.shape, targets.dtype, targets.shape)\n",
    "            \n",
    "            loss = F.cross_entropy(logits, targets)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            if i % 100 == 0:\n",
    "                print(f\"epoch={epoch}, step={i}: train loss={loss.item():.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(model, perm=torch.arange(0, 784).long()):\n",
    "    model.eval()\n",
    "    \n",
    "    test_loss = 0\n",
    "    correct = 0\n",
    "    \n",
    "    for data, targets in test_loader:\n",
    "        # send to device\n",
    "        data, targets = data.to(device), targets.to(device)\n",
    "        \n",
    "        # permute pixels\n",
    "        data = data.view(-1, 28*28)\n",
    "        data = data[:, perm]\n",
    "        data = data.view(-1, 1, 28, 28)\n",
    "        \n",
    "        # metrics\n",
    "        logits = model(data)\n",
    "        test_loss += F.cross_entropy(logits, targets, reduction='sum').item()\n",
    "        preds = torch.argmax(logits, dim=1)     \n",
    "        correct += (preds == targets).sum()\n",
    "\n",
    "    test_loss /= len(test_loader.dataset)\n",
    "    accuracy = correct / len(test_loader.dataset)\n",
    "    \n",
    "    print(f\"test loss={test_loss:.4f}, accuracy={accuracy:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parameters=6.442K\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "epoch=0, step=0: train loss=2.3063\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "epoch=0, step=100: train loss=1.1308\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "epoch=0, step=200: train loss=0.6939\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "epoch=0, step=300: train loss=0.4806\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "epoch=0, step=400: train loss=0.4875\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "epoch=0, step=500: train loss=0.3774\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "epoch=0, step=600: train loss=0.4050\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "epoch=0, step=700: train loss=0.3826\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "epoch=0, step=800: train loss=0.4308\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "epoch=0, step=900: train loss=0.2527\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([64, 10]) torch.int64 torch.Size([64])\n",
      "torch.float32 torch.Size([32, 10]) torch.int64 torch.Size([32])\n",
      "test loss=0.3561, accuracy=0.8968\n"
     ]
    }
   ],
   "source": [
    "# MLP baseline (input: 28*28 pixels per image)\n",
    "torch.manual_seed(0)  # seed PyTorch's global RNG (weight init, DataLoader shuffling) so this run is reproducible\n",
    "\n",
    "input_size = 28*28  # pixels per MNIST image\n",
    "output_size = 10    # one output per digit class\n",
    "\n",
    "n_hidden = 8\n",
    "mlp = MLP(input_size, n_hidden, output_size)\n",
    "mlp.to(device)\n",
    "print(f\"Parameters={sum(p.numel() for p in mlp.parameters())/1e3}K\")\n",
    "\n",
    "train(mlp)\n",
    "test(mlp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parameters=6.422K\n",
      "epoch=0, step=0: train loss=2.3233\n",
      "epoch=0, step=100: train loss=0.4378\n",
      "epoch=0, step=200: train loss=0.1719\n",
      "epoch=0, step=300: train loss=0.2179\n",
      "epoch=0, step=400: train loss=0.3929\n",
      "epoch=0, step=500: train loss=0.2728\n",
      "epoch=0, step=600: train loss=0.1188\n",
      "epoch=0, step=700: train loss=0.0639\n",
      "epoch=0, step=800: train loss=0.2077\n",
      "epoch=0, step=900: train loss=0.1554\n",
      "test loss=0.1042, accuracy=0.9667\n"
     ]
    }
   ],
   "source": [
    "# ConvNet with about as many parameters as the MLP (~6.4K)\n",
    "torch.manual_seed(0)  # seed PyTorch's global RNG (weight init, DataLoader shuffling) so this run is reproducible\n",
    "\n",
    "n_kernels = 6\n",
    "convnet = ConvNet(input_size, n_kernels, output_size)\n",
    "convnet.to(device)\n",
    "print(f\"Parameters={sum(p.numel() for p in convnet.parameters())/1e3}K\")\n",
    "\n",
    "train(convnet)\n",
    "test(convnet)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The convolutional network performs better with roughly the same number of parameters, thanks to the prior knowledge about images built into its architecture:\n",
    "\n",
    "* Convolution: exploits locality and stationarity of image statistics\n",
    "* Pooling: builds in some translation invariance\n",
    "\n",
    "What if those assumptions are wrong?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MLP vs ConvNet, on shuffled pixels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAA9EAAADOCAYAAAA5WIGgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy81sbWrAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAgyklEQVR4nO3de3yU5Zn/8WsSwhAwCQclk0CEVIJY8AQqJYuAtaRi9SfSg6u+LO66rRyFF62IS38VbX9EoaZYAyrUArqluLtycF3WJbtqkAKKEYWCgCCHCKQRxCRySEhy//5QQuNzjT4zmeNzf96v1/zBd+48cz/hujNzz5Nc4zPGGAEAAAAAAF8rJd4TAAAAAAAgWbCJBgAAAADAJTbRAAAAAAC4xCYaAAAAAACX2EQDAAAAAOASm2gAAAAAAFxiEw0AAAAAgEtsogEAAAAAcIlNNAAAAAAALrGJBgAAAADApXbROvCCBQtk7ty5cuTIEenfv7/MmzdPrr322q/9uubmZjl8+LBkZGSIz+eL1vSAoIwxUldXJ7m5uZKSEt77TOHWvwhrAPEVifoX4TkAyYvnANiM+ofNQqp/EwXLly83aWlpZtGiRWbHjh1mypQpplOnTubAgQNf+7WVlZVGRLhxi/utsrIy5vXPGuCWKLdw67+ta4D655YoN54DuNl8o/652XxzU/8+Y4yRCBs8eLAMHDhQnnrqqZbskksukdGjR0txcfFXfm1NTY107txZhsqN0k7SIj014Gs1yhlZL2vk008/laysrJC/vi31L8IaQHy1tf5FeA5AcuM5ADaj/mGzUOo/4r/O3dDQIBUVFTJjxoxWeVFRkWzYsMExvr6+Xurr61v+XVdX98XE0qSdj8WDOPjibaVwfo0o1PoXYQ0gwbSh/kV4DoAH8BwAm1H/sFkI9R/xxmJHjx6VpqYmyc7ObpVnZ2dLVVWVY3xxcbFkZWW13PLy8iI9JSBmQq1/EdYAvIXnANiM5wDYjPqHTaLWnfvLO3hjjLqrf/DBB6WmpqblVllZGa0pATHjtv5FWAPwJp4DYDOeA2Az6h82iPivc59//vmSmprqeMepurra8c6UiIjf7xe/3x/paQBxEWr9i7AG4C08B8BmPAfAZtQ/bBLxK9Ht27eXQYMGSVlZWau8rKxMCgsLI/1wQEKh/mE71gBsRv3DZtQ/bBKVz4meNm2a3HXXXXLVVVfJkCFDZOHChXLw4EEZN25cNB4OSCjUP2zHGoDNqH/YjPqHLaKyib7tttvk2LFj8sgjj8iRI0dkwIABsmbNGunVq1c0Hg5IKNQ/bMcagM2of9iM+octovI50W1RW1srWVlZMkJuobU94qLRnJHXZbXU1NRIZmZmzB+fNYB4ov5hO9YAbEb9w2ah1H/UunMDAAAAAOA1bKIBAAAAAHCJTTQAAAAAAC6xiQYAAAAAwCU20QAAAAAAuMQmGgAAAAAAl9hEAwAAAADgEptoAAAAAABcYhMNAAAAAIBLbKIBAAAAAHCJTTQAAAAAAC6xiQYAAAAAwKV28Z4AALRFyoB+jmzvTL869v1hi9U81ae/n/hW/RlH9j91A9Sxy/ZcpeY9xmxXcyDRTd+7zZE9vOf/qGPTv7sv2tMBACBhcCUaAAAAAACX2EQDAAAAAOASm2gAAAAAAFxiEw0AAAAAgEtsogEAAAAAcInu3ACSWu7vP3JkL+WVq2Obgxyj2TSp+ZXtne8zXtlthzr2/iD5lSt+rOZ07UaiODSjUM2HdtjsyJqNL9rTAQAg4XElGgAAAAAAl9hEAwAAAADgEptoAAAAAABcYhMNAAAAAIBLNBYDEBfH7hmi5i/8cm5Ix8lJba+kqSEdo2jHGDX/a02GI3tvyNKQjv2HK5eo+f+Vq0M6DhAtJf+0SM1TeJ8dADyv+dor1fzDnzqzvdcvVseO/uC7al4/uYv+mFt3uptcAuMZEgAAAAAAl9hEAwAAAADgEptoAAAA
AABcYhMNAAAAAIBLbKIBAAAAAHCJ7txnpTi7+bbrnReHiYTG1NY5sqajx+IwEyA0F/zpPTW/6Ybxar61cImarz3VyZHdt/F2deyF/6J37e5Q/hd9/KWZznCVOjSoy7Xm4SJy6IFCR9bjsQ2hHRwIwaEZzpoTERna4a0gX+FcL4d3dVdHFsiH4U4LAPA1fGn6i4mUPr0c2e5fnKeO/cll69X8ns6lap6V0sGRnTH6/P6tzxo1X/yCvpdadYvzE1qadu/VD56guBINAAAAAIBLbKIBAAAAAHCJTTQAAAAAAC6xiQYAAAAAwCU20QAAAAAAuER37i989MBgR/bOpCfiMJPQzD9+sSP7738aqg/etDXKswHcaz55Us173613+R2d/X39QPUNjqjPoS2hzSVI7gvpKLqUIO9VNvkjcHAgBPWd9baqaT69a72m7/OfqXmQhq2wTGrnrCB3uK+xY99zvq4REfn4mmA/qduu7+IgdV2xPWqPCWgarx+k5v0f01/DP56zvM2P+bMjI9T81YN9HdldBfqnOUzrulPN/yGzUs2f73+zI+tId24AAAAAALyJTTQAAAAAAC6xiQYAAAAAwCU20QAAAAAAuMQmGgAAAAAAl+jO/YVT/U47sl99PDCkY7xzPE/NPznV0ZE1G73v7/kdT6j5zgM5al5x/ZOObOKLu9SxVy6Youb5zx1U88bKj9QciKbmE/oaaP5QzyOhXSBbzfeMPq/Nxz7e7PzZIiLSZXf0Os0Cmlm3/mtI4yvqnVnKSWc3fBGRpnAmhITXrkeumu+a2kvNV/+wRM37prV3/ZhpvlfV/IyJXpWtLOqq5g8vudOR9Vr0gTq26eOPIzoneEfdbd9yZB/for82WFVYqubB1tCWBudriTs3/kQde/5/dFDzzivfVfMep53d6ZetukodG6w7t5dxJRoAAAAAAJfYRAMAAAAA4BKbaAAAAAAAXGITDQAAAACASyE3Flu3bp3MnTtXKioq5MiRI7Jy5UoZPXp0y/3GGHn44Ydl4cKFcvz4cRk8eLDMnz9f+vfvH8l5R1zB2Hcc2WZJDfEoh9U0K4QjBGubURDk2N/6fz9zZA//YLk69r0JziZkIiJ98+9V82/+2vkeS+N+vQmZLbxa/7Zr6KM37vvL3XqDj1BMq/yemmcs39TmY8ca9Z8c/npfoZrfmbFAzc8Y/Ti3rx3vyPrueCvseXlBvNdAu2/0dmTv/7KbOrbfz/eruS8zQ813jXf+HPzD959Sxw7xB3u1ojc/2n3G2ZBu06l8dWzJjuvV/Ad93lXzVUuGB5mLU6/RH6r5M/n/rubvTnS+bnrurh7q2BeHXarmXmo4Fu/6TxQpHfQGXTvnD1DzN4sed2RdUvRjzP/0EjW/7fkb1bzHnDcd2Tea31XHBhNSm9PyLnp+dUgP6QkhX4k+ceKEXH755VJaqr+4nDNnjpSUlEhpaals3rxZAoGAjBw5Uurq6to8WSDeqH/YjPqH7VgDsBn1D5wT8pXoUaNGyahRo9T7jDEyb948mTlzpowZM0ZERJYuXSrZ2dmybNkyufde/YonkCyof9iM+oftWAOwGfUPnBPRv4net2+fVFVVSVFRUUvm9/tl+PDhsmHDBvVr6uvrpba2ttUNSEbh1L8IawDeQP3DdqwB2Iz6h20iuomuqqoSEZHs7OxWeXZ2dst9X1ZcXCxZWVktt7y8vEhOCYiZcOpfhDUAb6D+YTvWAGxG/cM2UenO7fP5Wv3bGOPIznrwwQelpqam5VZZWRmNKQExE0r9i7AG4C3UP2zHGoDNqH/YIuS/if4qgUBARD5/Nyon51yXx+rqasc7U2f5/X7x+/2RnIZVes/c6Miee07vbNmw+nU13zlK7755/xXODq97f9hLHdu470CQGdojnPoXYQ14xdGmU2q+98l+ap4pyded+6tQ//GRkuHstHznT/9bHdtkQurBKpf8xtlVOFhPZsRmDdz4coUjW5W1Xx37z2uvUvMRmX9W
86L0E67mICKysKa3mj+95GY1D7zp/PmYUr5FHdtTtqv5JknTjy3Bf1X4y+p/q+dF909X8/+aPMeR/TjzkDq2eMZoNe/71Hlq3rRnnz6ZJGXTc0DlsovUfPfgZ9R85PY7HVnqY3pXfX/FHjXv8an7Oo+mnNdr9Dt+Htt5JIKIXonOz8+XQCAgZWVlLVlDQ4OUl5dLYaH+kRuAV1D/sBn1D9uxBmAz6h+2CflK9GeffSZ79px7l2Tfvn3y7rvvSteuXeXCCy+UqVOnyuzZs6WgoEAKCgpk9uzZ0rFjR7njjjsiOnEgHqh/2Iz6h+1YA7AZ9Q+cE/Im+u2335brrruu5d/Tpk0TEZGxY8fKkiVLZPr06XLq1CmZMGFCywetr127VjKUXzkDkg31D5tR/7AdawA2o/6Bc0LeRI8YMUKMMUHv9/l8MmvWLJk1a1Zb5gUkJOofNqP+YTvWAGxG/QPnRKU7NwAAAAAAXhTR7txIDE279M5+y8d8W81Pv7hezefmODsB3vqHW/QHvU6PgUSSesEFav7ZL2rbfOz3GvROm5l/8lYXbiSWlM5Zjmxql9eDjNY/ZuaHe27Uhx/9JLxJIWry0o65Hjs7+201/6hR/ySBcZU3OLK3Vlymz+MPu9Q892hidBAOVe5cfd7f7na/I9t+V6k69v2/n6/mF7efoOYFk73VnduL2vXSP7N6wRV/VPNr3nZ24RYRCdx+0JE1n9yvjk2kT0Bol+/8RJ4hS52fEPBVgv28Sa86HdacEglXogEAAAAAcIlNNAAAAAAALrGJBgAAAADAJTbRAAAAAAC4xCYaAAAAAACX6M79hcZvD3JklT9pVMd22NypzY+X8VGzmp/3r9Hr5Nu0Y7eaF28apeZji552ZHf3+LM6dknf7+iPuXuvy9kB0ffBEz3UfMelz7o+xoSPhqn5h//cT83bSWidLIFQvP/r7m0+RsOYM2re9GlNm4+NyHp69M2O7PF8Z4f2r5L2mf7aJvX1dxxZruhdqxOpg3A0FZQ6uyqv/YH+GrAo/US0p4MYqx2Uq+ZD/PoKSFvZRc2bT+6M2JzcSBmgvx6pvKmrmp+4uF7Ny69/wpHlpKaHNJfv/X66mudtTM5O/n+LK9EAAAAAALjEJhoAAAAAAJfYRAMAAAAA4BKbaAAAAAAAXKKx2Bdqe7d3ZNuGORtriYiI3lcoJJ806X/Ev6NYbxDywKM/VfNuiza2eS79njil31HkjG7pdFQd+ttBF6h5Jo3FEAem8HI1Xz5kYZCvSHV97Lef14/d/X+Tv0kGElf1pEI1/+A7pUrqU8cWrBqv58feDHdaiLGm7bscWYftcZiIJRo/OuTISvYrL45EpOiSldGeDhLcXfevUfOSa0c6ssxtzn1HOE53M47slR/PVcf2bBdaUzCRUMc75S8+oOZ6e8PkwpVoAAAAAABcYhMNAAAAAIBLbKIBAAAAAHCJTTQAAAAAAC6xiQYAAAAAwCW6c3+h6+JNjuzWFd+JyLE/uqe/I3tk3HPq2FEdj6t52S8fV/Oixp85stqL9Hn0/oXeyTul9qT+BUASaB5+pSPrX7JNHXtZe/dduEVE/utkhiPLeeWIOrYppCMDoTnz7Ro1bxZnZ1YA0dM4L6Dmf12gf9LJRd88HM3pIIo6rnpbzW+YcKua/67PC2o+7oYPlYOEPS0X9K7aD1U7Xy+JiPxdxm41L0o/4chOmgb9GE869yMiIj0Otf1ThBIVV6IBAAAAAHCJTTQAAAAAAC6xiQYAAAAAwCU20QAAAAAAuMQmGgAAAAAAl+jOfZZxdjht+lTvhhqqnMc3OLJn/udmdewvHtLf16gYvETNN/y61PU80v5R70x8xlQE+Qr377E03PGJmn9QONj1MYK57LL9ar74GyvV/O/zCtv8mEg8qd26qvnHP3N2l58beDOkYx9t0ruqznhusiPL2+Ncz0Ay6PfQHjWnszzgXoeX31LzPU9kqvmafqvU/CYZFKkpIVqa9Z+O7b5z
UM1/ftk/qPnOic5P+vClR+Ynb9abHRxZt+2n1bGp695T80/fvELNi9L/7Mg2nu6sju3xmH2vjbgSDQAAAACAS2yiAQAAAABwiU00AAAAAAAusYkGAAAAAMAlNtEAAAAAALhEd+44aX7vfTXvMUYfP/i+KWreMKzWkb141UJ1bJ80vz4XadYfNAQbBv5Rv2Ngmw8tG0/r875l0lQ1Txe9cyaS2/5nctX8vUFL23zsKQdvUfO8X9nXbRLxFawL/RWBQ66P8eTxAjVvOnosrDkBAL5e89adat733hhPJIgP5wxR85dz9U/6ef/MGUdWPHWsOraDha+9uRINAAAAAIBLbKIBAAAAAHCJTTQAAAAAAC6xiQYAAAAAwCUaiyWJ7N8FaXD0O2d039AJ6tDqqzuq+UW3fqDmf7pojSP7pKleHfsfJ/rq8wviN6ucjZzSq33q2Nzfb1Pz9Dr7mhgkumBNkXaW9HZk52Wd0g+yrosaPz9oXrBH/fqJfaHv6vFqfsm8YA2XaMSE2Dp248VqvrqX3vhF8/TL31XzfNkY1pwAfL0Un96kNUX01zZAtDTccLWav/Sjx9X8YKNeu5OmTHNk6S/z2vssrkQDAAAAAOASm2gAAAAAAFxiEw0AAAAAgEtsogEAAAAAcIlNNAAAAAAALtGd24NS1r+r5jmb9P/ufSP6qLnWifv6hdPVsXm/DtI9PIhQusTqPQMRTyfHDFbzfg/8Rc1X91zk/uB6U0kJpQt3MJf2P6jm28f3VvN+c044Q397dWzjfv3YgCa1i96F/uIJ20M6zpEmZ5f7Ps/pXeWbQjoygFA0G/26VDMrD1GUen43RzZ23mp1bJ80v5pftmiyml+4OrTX9rbhSjQAAAAAAC6xiQYAAAAAwCU20QAAAAAAuMQmGgAAAAAAl9hEAwAAAADgUkjduYuLi2XFihWyc+dOSU9Pl8LCQnnsscfk4osvbhljjJGHH35YFi5cKMePH5fBgwfL/PnzpX///hGfPEKTkpWp5puuel7Nl9b2dWShduH2EhvrPyUjQ827TDmg5gt6rovmdNrssV4r1Lz8ggI1n1d7iyNLr9aP3X2+97tz27gGouX01Rep+bMXPhPacYzPkTXt2B3WnPDVqH+EY3tDY7ynEDGsgfhJzdRfw2euNo7szowj6th7Dl6n5r2KK9TceWT8rZCuRJeXl8vEiRNl06ZNUlZWJo2NjVJUVCQnTpz7GJg5c+ZISUmJlJaWyubNmyUQCMjIkSOlrq4u4pMHYon6h+1YA7AZ9Q/bsQaAc0K6Ev3KK6+0+vfixYule/fuUlFRIcOGDRNjjMybN09mzpwpY8aMERGRpUuXSnZ2tixbtkzuvfdexzHr6+ulvv7c5xHX1taGcx5A1EWj/kVYA0gePAfAZjwHwHY8BwDntOlvomtqakREpGvXriIism/fPqmqqpKioqKWMX6/X4YPHy4bNui/BlxcXCxZWVktt7y8vLZMCYiZSNS/CGsAyYvnANiM5wDYjucA2CzsTbQxRqZNmyZDhw6VAQMGiIhIVVWViIhkZ2e3Gpudnd1y35c9+OCDUlNT03KrrKwMd0pAzESq/kVYA0hOPAfAZjwHwHY8B8B2If0699+aNGmSbN26VdavX++4z+dr3ejEGOPIzvL7/eL3+8OdBhAXkap/EdYAkhPPAbAZzwGwHc8BsF1Ym+jJkyfLSy+9JOvWrZOePXu25IFAQEQ+fycqJyenJa+urna8K4XYO/PNXmr+SVO9mpcuGu3IcsTe7txn2VT/6f/ZQc1fuOg/YzwTkZOmQc3nHRvkyJa/OEIdm//vR9U8WDfjXtS7yqY1kOg+ONMt3lOwDvWPUIx5+T41L5A3YzyTyGENRE9qt65q3n5Fmpo/33uNI/tjXY4yUuTo2AvU3NTvdTk7/K2Qfp3bGCOTJk2SFStWyKuvvir5+fmt7s/Pz5dAICBlZWUtWUNDg5SXl0thYWFkZgzECfUP27EGYDPqH7ZjDQDn
hHQleuLEibJs2TJZvXq1ZGRktPx9Q1ZWlqSnp4vP55OpU6fK7NmzpaCgQAoKCmT27NnSsWNHueOOO6JyAkCsUP+wHWsANqP+YTvWAHBOSJvop556SkRERowY0SpfvHix3H333SIiMn36dDl16pRMmDCh5UPW165dKxkZGRGZMBAv1D9sxxqAzah/2I41AJwT0ibaGPO1Y3w+n8yaNUtmzZoV7pyAhET9w3asAdiM+oftWAPAOWF350by2fuj9mr+dyt/ruYFJTRVst2vLnwpyD16LYXis2a9od3ANVPUvNvb+o+rbos2OrILgzQEa3I5NyDRNEuzms9Y+I+OLJeGeEDU1N32LTUf0P7PMZ4Jkllqly5qvvO3vdV8V59Fal7TfNqR/cv4m/TH3P2Ou8nBlbA/JxoAAAAAANuwiQYAAAAAwCU20QAAAAAAuMQmGgAAAAAAl9hEAwAAAADgEt25LTLwir1qXvXERTGeCZLF5J9OVvPnfj9PzbNT09W878vjHNk3Hzmkjz202d3kAItc+sY9ap4/h07cQCx9Y8pONc9K6RDjmSCZffBAPzXfdX1pSMe5/jf3O7LAazwvxAJXogEAAAAAcIlNNAAAAAAALrGJBgAAAADAJTbRAAAAAAC4xCYaAAAAAACX6M5tkZM3n1HzTp++GeOZIFmkrX1bze+5cGhIx+krzo7bjWHNCPCGYGvrph6D1DxftkZzOgAUdbd9y5H9Me/xIKP17twXL6pR8+ZwJ4Wk0y4n4Mjmfv/5kI4x8K271Dz3dxvDmhPajivRAAAAAAC4xCYaAAAAAACX2EQDAAAAAOASm2gAAAAAAFxiEw0AAAAAgEt057ZI06d6h0gAAAC0Vt/Z58iyUvQu3E8eL9APsudgJKeEZJTivGa57K+D1aEFPV9Wc98bnfVjGxPurNBGXIkGAAAAAMAlNtEAAAAAALjEJhoAAAAAAJfYRAMAAAAA4BKNxQAAAIA2eG7hDWqefXJDjGeCRNN46LAjqxmqj50mQ9Q8R6ijRMOVaAAAAAAAXGITDQAAAACAS2yiAQAAAABwiU00AAAAAAAusYkGAAAAAMAlunMDAAAAX9LtL6cc2S3X/Ugdm/3BxmhPB0AC4Uo0AAAAAAAusYkGAAAAAMAlNtEAAAAAALjEJhoAAAAAAJcSrrGYMUZERBrljIiJ82RgpUY5IyLnajHWWAOIJ+oftmMN4CzTeNoZNtWrY5vMmSAHSa7/ROofNgul/hNuE11XVyciIutlTZxnAtvV1dVJVlZWXB5XhDWA+KL+YTvWAGTT6njPIG6of9jMTf37TLzeagqiublZDh8+LBkZGVJXVyd5eXlSWVkpmZmZ8Z5a1NTW1nr+PJPpHI0xUldXJ7m5uZKSEvu/eLBtDSRTbYQrmc6R+o+tZKqNtkim82QNxFYy1Ua4kukcqf/YSqbaaItkOc9Q6j/hrkSnpKRIz549RUTE5/OJiEhmZmZCf8MjxYbzTJZzjMe7r2fZugY4x8RB/ceeDecokjznyRqIPc4xcVD/sWfDOYokx3m6rX8aiwEAAAAA4BKbaAAAAAAAXEroTbTf75eHHnpI/H5/vKcSVTacpw3nGA02fN84RwRjw/fNhnMUsec8I82G7xvniGBs+L7ZcI4i3jzPhGssBgAAAABAokroK9EAAAAAACQSNtEAAAAAALjEJhoAAAAAAJfYRAMAAAAA4BKbaAAAAAAAXEroTfSCBQskPz9fOnToIIMGDZI33ngj3lMK27p16+Tmm2+W3Nxc8fl8smrVqlb3G2Nk1qxZkpubK+np6TJixAjZvn17fCYbpuLiYrn66qslIyNDunfvLqNHj5Zdu3a1GuOF84wVL9W/iPfXAPUfeV5aA16vfxHWQKR5qf5FvL8GqP/Iov6TrzZsWwMJu4l+4YUXZOrUqTJz5kzZsmWLXHvttTJq1Cg5ePBgvKcWlhMnTsjll18upaWl6v1z5syRkpISKS0tlc2bN0sgEJCRI0dK
XV1djGcavvLycpk4caJs2rRJysrKpLGxUYqKiuTEiRMtY7xwnrHgtfoX8f4aoP4jy2trwOv1L8IaiCSv1b+I99cA9R851H9y1oZ1a8AkqGuuucaMGzeuVdavXz8zY8aMOM0ockTErFy5suXfzc3NJhAImEcffbQlO336tMnKyjJPP/10HGYYGdXV1UZETHl5uTHGu+cZDV6uf2PsWAPUf9t4eQ3YUP/GsAbawsv1b4wda4D6Dx/1743a8PoaSMgr0Q0NDVJRUSFFRUWt8qKiItmwYUOcZhU9+/btk6qqqlbn6/f7Zfjw4Ul9vjU1NSIi0rVrVxHx7nlGmm31L+LN2qD+w2fbGvBqbbAGwmNb/Yt4szao//BQ/5/zQm14fQ0k5Cb66NGj0tTUJNnZ2a3y7OxsqaqqitOsoufsOXnpfI0xMm3aNBk6dKgMGDBARLx5ntFgW/2LeK82qP+2sW0NeLE2WAPhs63+RbxXG9R/+Kj/c5L5nG1YA+3iPYGv4vP5Wv3bGOPIvMRL5ztp0iTZunWrrF+/3nGfl84zmmz8PnnlnKn/yLDte+Wl82UNtJ2N3yevnDP133Y2fp+8dM42rIGEvBJ9/vnnS2pqquNdierqase7F14QCARERDxzvpMnT5aXXnpJXnvtNenZs2dL7rXzjBbb6l/EW7VB/bedbWvAa7XBGmgb2+pfxFu1Qf23DfV/TrKesy1rICE30e3bt5dBgwZJWVlZq7ysrEwKCwvjNKvoyc/Pl0Ag0Op8GxoapLy8PKnO1xgjkyZNkhUrVsirr74q+fn5re73ynlGm231L+KN2qD+I8e2NeCV2mANRIZt9S/ijdqg/iOD+v9cMtaGdWsgZi3MQrR8+XKTlpZmnn32WbNjxw4zdepU06lTJ7N///54Ty0sdXV1ZsuWLWbLli1GRExJSYnZsmWLOXDggDHGmEcffdRkZWWZFStWmG3btpnbb7/d5OTkmNra2jjP3L3x48ebrKws8/rrr5sjR4603E6ePNkyxgvnGQteq39jvL8GqP/I8toa8Hr9G8MaiCSv1b8x3l8D1H/kUP/JWRu2rYGE3UQbY8z8+fNNr169TPv27c3AgQNbWqQno9dee82IiOM2duxYY8znbd8feughEwgEjN/vN8OGDTPbtm2L76RDpJ2fiJjFixe3jPHCecaKl+rfGO+vAeo/8ry0Brxe/8awBiLNS/VvjPfXAPUfWdR/8tWGbWvAZ4wxkbmmDQAAAACAtyXk30QDAAAAAJCI2EQDAAAAAOASm2gAAAAAAFxiEw0AAAAAgEtsogEAAAAAcIlNNAAAAAAALrGJBgAAAADAJTbRAAAAAAC4xCYaAAAAAACX2EQDAAAAAOASm2gAAAAAAFz6/1GZoeKjpbZzAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 1200x400 with 5 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAA9EAAADOCAYAAAA5WIGgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy81sbWrAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAu9ElEQVR4nO3de3xU9Z3/8c9kAuFiMhaQBEKAAAERFAQVRRS6rSguAtrW60Ol2+0Py6Uiblmt64q9gOIWaQW8LQW7P69bQah1VVoliIgiQlGQe4AQiAjVJHJJmMnZP1yDw/kcOWfmnJkzc17PxyN/5JOT73zP5POdM99M8p6QYRiGAAAAAACAU8pJ9wQAAAAAAMgUbKIBAAAAALCJTTQAAAAAADaxiQYAAAAAwCY20QAAAAAA2MQmGgAAAAAAm9hEAwAAAABgE5toAAAAAABsYhMNAAAAAIBNbKIBAAAAALAp16uB582bJw899JDs379f+vTpI7Nnz5ZLLrnklN/X2Ngo+/btk/z8fAmFQl5ND7BkGIbU1dVJx44dJScnsd8zJdr/IqwBpJcb/S/CNQCZi2sAgoz+R5A56n/DA88995zRrFkz48knnzQ2bdpk3H777Ubr1q2N3bt3n/J7KysrDRHhg4+0f1RWVqa8/1kDfPjlI9H+T3YN0P98+OWDawAfQf6g//kI8oed/g8ZhmGIywYNGiQDBgyQRx99tKnWu3dvGTNmjMyYMeMbv7empkZOP/10GSJXSq40c3tqWWf3tAtMtS7T3lOPDeXqf3hgRKNq/bktH5hq1/ca4GB21rwcO1lROS4r5RX5/PPPJRKJOP7+ZPpfJDvXwJFR55lqrZa+72iMF7asV+vX9upvqlVPHKQe2/kfd6n1hhGfOJqLGz4dZ57jGY+/62gMJ/eJXcn2vwjXgHSrud58XRARiTynXxsylRf9L+Kfa8DgC6dKbm5e3NdyVm5wPB/4U3RpJ1Mtd9TeNMwknl/6//xnxkluq+ZxX8sbs8fxfHCCV4+ZidDm4tY8cvr2NNUaN27TDz5pG+yk/13/c+6GhgZZu3at3HXXXXH14cOHy6pVq0zH19fXS319fdPndXV1/zexZpIb4gnUqeS0aGGqWd1voZDFJtriz2UK8s1/xuDWz8TLsZP2f+spkT8jctr/IsFYA7nN7PepFa1nrMYJ55lvT0SkWevmar0xDfdzuHlq7xPbkuh/Ea4BfqD1loiPHmNd4kn/i/jnGpCbJ7m58T/LnCz7GQZa6zxTyRdr1C/936q55J50H/ni/slgnj1mJsDLfUBO2Ly2rJ/nGeqndvrf9WCxgwcPSiwWk8LCwrh6YWGhVFdXm46fMWOGRCKRpo+SkhK3pwSkjNP+F2ENILtwDUCQcQ1AkNH/CBLP0rlP3sEbhqHu6u+++26pqalp+qisrPRqSkDK2O1/EdYAshPXAAQZ1wAEGf2PIHD9z7nbtWsn4XDY9BunAwcOmH4zJSKSl5cneXnml9014Tc7mos/qDfXRCR28JCtMd30vY8PqPUXe7f37DZL737HVFtatUY9dlTx+Y7G1o5vvORc9dict9YlPbZjOWG93hhLfuwEOe1/EWdr4OWqtaba60dbq8f+rseZtsZ00xevdlPrp13h7H99NSOLB6p17T4ZWayPsfhO8//ii4hcKfb/Hz98lvl/bUREYpu22h5DRKT9HP1P25ywuk+cOPnxorauUYp6JT6el9cAv2tZrq/xF3v8j6l2ZbF3GRCn/5f5upCN3Oh/t7l5DchZucH059v6453/7gcv3LNzvVr/dbf+ppp2P4n4677K/a53/987YZv5ejS3TL92ucnN/v9Njz/KaSf9ye/iDebnoKuuO1sdN/axxf/AWoh+x9wbLTZV6cfu119Vt7LrlxeZal3vdedx2sljghv9H+7ZXa3Htu5wNI6Xa7Fxw2bbx568dzv6RVSWm2N8VK6/Et28eXMZOHCgLFu2LK6+
bNkyGTx4sNs3B/gK/Y+gYw0gyOh/BBn9jyDx5H2ip0yZIjfffLOcd955ctFFF8kTTzwhe/bskdtuu82LmwN8hf5H0LEGEGT0P4KM/kdQeLKJvu666+TQoUPyi1/8Qvbv3y99+/aVV155Rbp06eLFzQG+Qv8j6FgDCDL6H0FG/yMoPNlEi4iMHz9exo8f79XwgK/R/wg61gCCjP5HkNH/CIKQYZz0LtNpVltbK5FIRIbJaFvvF3ZgiR6e1PGOY2o9unNXMtNLSOW/mf8PpORXyYcKOXX+ej1wa81teliYrN7g4Wzsy+l/llpvXL/Jk9uLGsdluSyRmpoaKSgo8OQ2vonTNbD94QvVes979J9f45Ejav21fetNtcs79j/l7duR27WzqRbd5SxUJROCYrxide7nvnuLWi++ZmPCt5Vp/R8UQe5/p5IN3/LLGlj1UQdTsNKu6LdMx6cjRDJT7XjIHPAkItLYslGtl01MPhQzHUID+5hqxlp71wW/9P/Wjwsl/6T+v7nkYtPxe6bp/2vdeVrqn2d7KWzxs4jV1iY9tt+vL07n1+t983OHLecdt3VbTvrfs7e4AgAAAAAg27CJBgAAAADAJjbRAAAAAADYxCYaAAAAAACb2EQDAAAAAGCTZ29xlSrtR29W69EUz+ObuJHEPXuXPsbkUnNSoVgErt97xgdqfdTqcMLzclP1HXrCYtHD3iUs/nDLblPtyBcxWT7As5tMyn9Vvm2q3VyiH6tnjYrM271SrV/ecYip5lZio9MkbsTr986tar3k+x+leCYQSU+S6SOflXk2thM7H9DTjbvd9Y5nt+n0/nbyc/jkp+brTqz+mMhjS2yP4ZWf972AhHqX9Xxgi1qPHfq77TG2zRmk1v2U5G03iVvE/NwrVn9MZG76+///9b7Q1P968r53c7B67LHi5TXAjRRuK07mnY7r35lv/LNa7yHr1LrdJO5k8Uo0AAAAAAA2sYkGAAAAAMAmNtEAAAAAANjEJhoAAAAAAJvYRAMAAAAAYFPIMCyinNOktrZWIpGIDJPRpFLakFvSyVSLVu5Vj615pYdaj1y53dU5pdu2R8zJmWvHPKwee32JOZk1ahyX5bJEampqpKCgwPX5nYpbayAdCYpesjqfMQOvNNWi1Z94PR3bMu3nkC39HxR6Wq0/eytTsAYyy4+2Vphq83uWqsf66fHYydrVntd8k7JJiSeF0/+ZJadFC1Ot8dgx9di66y5U6/nPr3Z1TpnMSf/zSjQAAAAAADaxiQYAAAAAwCY20QAAAAAA2MQmGgAAAAAAm9hEAwAAAABgU266J+AXS6vWmGqjis9Pw0ycsUri1rT+j4hn88jt1lWtR3fusj2G9jMQcf5z0FIpr59kTuH2u+2/HSA5LeNTF3uO0+8jjRtpo89VrlLrWqq5iLfJp9Zj+CeJW+P3pOSTf2a1dY3SvleaJuMzx4efp9abvf6+Wk9H8q/f+8vvwt/6lqlmGA0in6VhMjaQxm5WF2tp+1g/3VdO5vLXUb9R67d1GeLWdKDQHh9ERGKf+ecBYumOt001q97yewr30dEXqPWWS95L8Uzs4ZVoAAAAAABsYhMNAAAAAIBNbKIBAAAAALCJTTQAAAAAADYRLPZ/nIRX/f2fLlLrbX7/jlvT8UTuG3rojaMxLALEjPkN+jd82/44o4oTm1O26nH7B5IbanbK4xqWdVHrzS/bnfQcrALEMtUvK/RgtntL7a9/p+FRBybo92H7uXpomxPzdq9U6+MdhM1saIjFff5FQ2NSc3LL7gV9JKdVfLBe6fUbTMfduX2j+v2/6dEn6TlYBYhZGX3ZDRZf2Zr0XIIsJz9frTfW1SU99uZf9TSPe/SYyJ1JD+0JPwVjpdq+qfpj6Qu9k38sdcOsXfpzwCld9eeMmtiwAWr9Nv0yn9Ve2LJeCvLjX+vT+r/xknPV7895a13Sc3AaIHbD5n1q/dkzOyY9l9wuJWp9pIPnzjnnnKnWGzdstj1GqFlz
tW4ct9gHOJCOALFkwlV5JRoAAAAAAJvYRAMAAAAAYBObaAAAAAAAbGITDQAAAACATWyiAQAAAACwybfp3H8fe4GEm8cns7Z90h/p135K4V5apacNa5wkkFuJ7tyl1mufuFCtF4ieVGg1jhML9+jJxGM7208mzgavn7VIrY+U1Ke4+j051kkKtxWn5/jezx/Rx5lrHme+RU//yKKnrVK4tQRxq3n/vPSCuM+jxnER2asem0qtVp9mugZopvznj9V6saQ+sTe2KbtSuHNLOplq0crU94bTFG4n/V824V1TLWocl0pHt4hUyN9t/50Dwm3bqPXYob+r9W2PDFLrZZPM/WGlp0VqsRPh5R8kPUa2GPrbcRLOi78GFCmP626kcLvFjRRuK9Hd+qPSQ7tWm2o/66o/J3eSwm3FKoXb6TuX+MXJ8/vyOZC9d7fhlWgAAAAAAGxiEw0AAAAAgE1sogEAAAAAsIlNNAAAAAAANrGJBgAAAADAppBhGEa6J/F1tbW1EolEZJiMltxQs4TH2fWri9R613+zn6z9o60Van1+z9KE5vR1e+8erNY7zUh9oiziRY3jslyWSE1NjRQUFKT89p2uga2P64nTo8/TEys/Hhi1PZfO77ZW63sGHbY9Rjo4TYmctUt/XJjSVX8c8Tur89doyZR+7X8n52XF70mhbjg+/Dy13uz191M8E//Y9ls9rbbsdnOyrZ/XgBvqrtPvi/znzffFT7frab6/63Gmo9s8OM78WNrucf+800m2ufbjalPthd5Ftr7Xz/2f09fcd1umtlLH2fKdJ9W6k2tApiZOB131HeY91vv/YvFOKUk8B+KVaAAAAAAAbGITDQAAAACATWyiAQAAAACwiU00AAAAAAA2sYkGAAAAAMCm3HRPwImlVWtMtVHFejKxVQq3NoaIyNkr/tlUm9/TweQccprCndOvt1pv/NvHbkwnaeGz9DsrtmlrimeS3XJam9Oye47Te9qNzpjXaYVaHyn+TqZ0mpz5/d/fqdY7i/11Gj7jDLUe+/RT22M4TQJtvUK/Te34nc/0V4/dVDU/7vPaukZp3+sbJplGJKLak44U7leqPlDrVxYPsD1G4Tt6EuonF9UmNKev2/b9eWr9ytvtz88PPr/FnHJ9+h+cpVxPn/6EWp/x/DmmmtMUbitaEvflH+k/19f6OkuE1h43vXysCPfsrtZjW3d4dptO2U3izjQNhebnQGW3WFw3HT5POXLNIPMYxY6G8NS2P+iPVWW36I+9Thy52nzuIiKtFr+b9NjpUPSw+bnbuXmT1GOLHTzPOxmvRAMAAAAAYBObaAAAAAAAbGITDQAAAACATWyiAQAAAACwKWQYhuHkG1asWCEPPfSQrF27Vvbv3y+LFy+WMWPGNH3dMAy5//775YknnpDPPvtMBg0aJHPnzpU+ffrYGr+2tlYikYgMk9GSG2rm6GSSEW7bxlSLHfq7K2OnOvQCyYkax2W5LJGamhopKIgPOPG6/0VOrIEDW7pIQX7877m87JvZu8zhCpO7DnY0RqfVp6n1x0rKTTUvz8VpQFcQHB9+nlo/OYTKL/3/7QF3SW64RdzXjPc/sj2On6T6GnD41W5qvfUVOz27zWzilzWQ6udBSI5xcX+1Hnp7fUrnkSy/9P+/r/6utDgtvv/fONscLOYWLdjQjVBDEZHYMHMoWHh58oFgXvvJtu2m2qNlPdIwk9T5pv4/meNXog8fPiz9+vWTOXPmqF+fOXOmzJo1S+bMmSNr1qyRoqIiueyyy6Surs7pTQG+Q/8jyOh/BB1rAEFG/wMnOH6LqxEjRsiIESPUrxmGIbNnz5Z77rlHrrnmGhEReeqpp6SwsFCeeeYZGTduXHKzBdKM/keQ0f8IOtYAgoz+B05w9X+iKyoqpLq6WoYPH95Uy8vLk6FDh8qqVfr7cNXX10ttbW3cB5CJEul/EdYAsgP9j6BjDSDI6H8Ejaub6OrqahERKSwsjKsXFhY2fe1kM2bMkEgk
0vRRUlLi5pSAlEmk/0VYA8gO9D+CjjWAIKP/ETSepHOHQqG4zw3DMNW+cvfdd0tNTU3TR2VlpRdTAlLGSf+LsAaQXeh/BB1rAEFG/yMoHP9P9DcpKioSkS9/G9WhQ4em+oEDB0y/mfpKXl6e5OXlJXybS6vWqPVRxec7GseNJG6ruYxU5uLWvL2kzdFP8/ObRPpfxHoNXNurv61kVreSqJ0mcWv2XviFPhdJbSq203O3ug/HXPI9Uy26c5d6bE6LFmq98dgxR3Nxg54G7e1tut3/xgcfi2Gj/8OnR9R67POaU35vqqQ6FT4dKdyL976n1q/udEGKZ5K+a5fba8CuigcuUuuld72T8Jj40rY5g0y1sonvqsdmWgq329zu/xUXtrb1HCi3tItaj1bsVuvpePcOLYk7HfPYt/gstd7x6k1q/Y2a3kr1uHrs05Vvq/WbSi62NbdM5Oor0aWlpVJUVCTLli1rqjU0NEh5ebkMHpz8E3TAz+h/BBn9j6BjDSDI6H8EjeNXor/44gvZvv3E+4ZVVFTI+vXrpU2bNtK5c2eZPHmyTJ8+XcrKyqSsrEymT58urVq1khtvvNHViQPpQP8jyOh/BB1rAEFG/wMnON5Ev//++/Ltb3+76fMpU6aIiMitt94qCxculKlTp8rRo0dl/PjxTW+0/vrrr0t+fr57swbShP5HkNH/CDrWAIKM/gdOcLyJHjZsmBiGYfn1UCgk06ZNk2nTpiUzL8CX6H8EGf2PoGMNIMjof+AET9K5AQAAAADIRq6mc6eDn9Kir+7zXYuvmFNi/TTvT36qBz6M8jjN1y6r9N3jfUvVes7K9R7Oxn8umjZRrbcV/ySz6mnR3iVQOk0K7vfOrWr9b289ZapZzdsvKdwi+hzv2P6xeuzDPbT0zcyxeONf1LqfHmP97oxVp6v1BV3+aqpZ9X86UritUqn9cu1KVk6rlpITah5XazxyxHScVQq3G+m/53ygvzXRhgHWr0Z6pdPq09S61btCuMEqiduJfVP151gdZ65KemwrVj97TarfPSAZOa1amWovrVykHmt1Xp6eb05YLb9caX5Oko773SqF28r2y7U/w9ffzcgqhduNx6HoXzqr9cb/0FPfd/3A/Pi09YrHk57HyXglGgAAAAAAm9hEAwAAAABgE5toAAAAAABsYhMNAAAAAIBNbKIBAAAAALDJt+ncz235QAry4/f4T9V2MR33g9O2q99/fYmehujEF9deqNZPe2G1Wo99bk7hzgSFv0t9QuR1O64w1a44Y6N67IsW4cFVQ80pjSIiJSvtzS0T7b7f3Ndd7nPn57dwj/mOG9t5iCtje5lCueffzffJ1Z2cjVHy/Y/U+khJft5Wa2D0t6811WJbdzga28n9mukp3FYGvjtWrReL/ngCs08Hf67W3eh/L9139Qtq/Q93laR4Jt5oPHJUGkPRhL/f6ePujt8oz3kG6M930mH7r85S6y1EfzcGv3CSwv1K1Qdq/criAWrdjeTjk9+hJVZ/TOSxJba/P5W0dHqntj6uv3PD1pGPmWqOn7s0xtSyl8+Bbti8z1R79syOrowdO6QncTtx1tP6O8i8uvshU218F/05Z+5396j12hv08+z5I/PjlhfXM16JBgAAAADAJjbRAAAAAADYxCYaAAAAAACb2EQDAAAAAGCTb4PFru81QHJDzU553IvS3rM51HXSf8dwmgtjL61ao9ZHFeuBB07GcTqGlZ9sM4e2PVrWQz02fHpErVuHKXxqqjj9WZb82rtANL9yI0Rs21N6QMnYzkkPbanzu61NtXmdVqjHOg3gOHv4FlOt5heOhkgLpyFi0BVfQ4BYUP2hV3YEiFnJ6dtTcsJ5cbXGDZttf7/T0Knud/ojROzBinfV+r+WpngiaWAVIGbF6me57ZFBplrZJP1+PTlcNmocl02OZpFeVvdBTosWar3nOP35d6qDFN0IhRMRuSl/v6n2rDgLFsst1o9/6b0/mWpO57fppjlqfWRx
8sG1Bc/qj1mV95oDZ0t+6f6egVeiAQAAAACwiU00AAAAAAA2sYkGAAAAAMAmNtEAAAAAANjEJhoAAAAAAJt8m85dPXGQhPPik/U6zPIujVlLyRtZrB+bW9pFrUcrdtu+PbcStN0aR6Mlced266oeG925K+nbczq2WwnnmSSvvMhUe7HHn9VjrRIUy279QK1/urSXqXbGKHPytYh1qqSVH+7+jql26b9MUI8tEGcJsTVDDjk6XpPT90y13viROQl3z3+frR7b+QcfqnUnSZZupXXCOf0a4Ox+P/xqN7Xe+oqdCc0Jwdb40VZptPEuJVZ+dfAcF2cTz43HqnAPPW7bjRTuoD+WWiVxB0nlZD3pvPgBfS9xYII50bn9XP3Y0Ll91Pqrf35arR83YmrdDW70dLRqn2djOxnjp9v1dx/4XQ/9OdrIjZ+p9Zf7pObde3glGgAAAAAAm9hEAwAAAABgE5toAAAAAABsYhMNAAAAAIBNbKIBAAAAALDJt+ncLYYdlHCrvPjiLO9ur+/CiaZaV3lHPdZJCreIyGv71ptq9cZx9VirZOltCy2Slsc6S0l2Qku/HmWRWO4GY36DWt+24UK17mQu2ZLkXT+02lQbKe6kjWpJ3LnFHdVjrZLrrbxc9VdTbf6929RjFz97hrPBXaClcIuIhNu2MdU2DF6oHntkr96/13a6yPY83EqOdSNpOmi0++f48PPUY5u9/r5at0rhDrdra6rFDiafKu+1qkXmBNr2j7VUj31tweNqnb5L3Atb1ktBfvxrHU7uz9X9nCV7V99hTicuelhPuXXj5xrbXuHoeKvE7Ss3jzHVnF6j/EJ7vigicnnH/imdhx/smn+25LSKf5eebjeut/39VincVto/ak403/mAfv3udpe+P7D8OeWETaVQM6stWL1atUyc76RcpwzDYmxntv3BnHBedov+Di8tywvV+tGhn9i+va65etq2lZf7fMv2sUeuGaTWWy1KPMmeV6IBAAAAALCJTTQAAAAAADaxiQYAAAAAwCY20QAAAAAA2MQmGgAAAAAAm0KG4VKEm0tqa2slEolI9ZYSUyqllqR8dPQF6jgtl7znyfwywf4p5oRNEZEOs5wlFWrCvXqo9diW7UmP7aVtj+ipfGWTzKl8UeO4LJclUlNTIwUFBV5PzeSrNTBMRktuyFm6aiaxTJr0UZrv4L+ZE7dX9Wuehpnoplfoj3M/L9UfF+2g/0/YN1V/LO04M/nH0kzw0+3m1PpNx/TY47/0zfd6OiZePYYEdQ3kdutqqkV37lKP3fGQnlrc/Wd6ajEyh1/6/8IrfiG5zeLTuVu8nNrn9k9Xvq3Wbyq5OKXzcMqtx8YgvtOHk/7nlWgAAAAAAGxiEw0AAAAAgE1sogEAAAAAsIlNNAAAAAAANvk2WMyrQI3qO/SgmKKHvQuK0cI6XnrrRfVYq3/YX1q1Rq1rYWsL96xUjx3beYjFDLPf1B0fqvWZ3c821fwSqqGtge3/da7p+B43r3M0/tZ5euhUz/HJB3bk5OvhQo11dUmP7RdWgR0X/ftEtd52fmYF7fi5/4FUyLQ1cHCcHvLV7vHMeuw5lfDpEbX+j6t2mmpLz2rr9XSylp/7f8fT5udA3W9y9hxo9i79+f7krvr+QHPkaj2sttVic1itlTu2f6zWH+7R2/YYmUwLQXYjANlK9WSL/d/s+NskWAwAAAAAAA+wiQYAAAAAwCY20QAAAAAA2MQmGgAAAAAAm9hEAwAAAABgU+DSuYFT8XMyZaaqWtTHVCu+ZmMaZuIdq9Ruq8R9v6L/EXTZsgbC7fSE6tjBQwmP6aanK99W6zeVXJzimeDrsqX/gUSQzg0AAAAAgAfYRAMAAAAAYBObaAAAAAAAbGITDQAAAACATWyiAQAAAACwKdfJwTNmzJBFixbJ5s2bpWXLljJ48GB58MEHpVevXk3HGIYh999/vzzxxBPy2WefyaBBg2Tu3LnSp485nTcdllatUeujis9P8UxS
z41zP/innmq93VVbE5pTJsmG/hdJT4q0kyTu0Plnq3VjzYduTcc27b6yup8eOnSW19NJu2xZA36WqSnvR18rVestL69I8Uy8k2n9v/m+Hmq994Mt1Hp0b5WpNmqTnuS99Cw9+dsJt1K4tZRvq7HdWF9379ig1md0P8f2GJkqk9ZAbqditf7oqufV+o87DzHVtHcWEfH23UW8vM2dz/RX691uXK/WnTwHCiJHr0SXl5fLhAkTZPXq1bJs2TKJRqMyfPhwOXz4cNMxM2fOlFmzZsmcOXNkzZo1UlRUJJdddpnU1dW5Pnkgleh/BB1rAEFG/yPoWAPACY5eiX711VfjPl+wYIG0b99e1q5dK5deeqkYhiGzZ8+We+65R6655hoREXnqqaeksLBQnnnmGRk3bpxpzPr6eqmvr2/6vLa2NpHzADznRf+LsAaQObgGIMi4BiDouAYAJyT1P9E1NTUiItKmTRsREamoqJDq6moZPnx40zF5eXkydOhQWbVqlTrGjBkzJBKJNH2UlJQkMyUgZdzofxHWADIX1wAEGdcABB3XAARZwptowzBkypQpMmTIEOnbt6+IiFRXV4uISGFhYdyxhYWFTV872d133y01NTVNH5WVlYlOCUgZt/pfhDWAzMQ1AEHGNQBBxzUAQefoz7m/buLEibJhwwZZuXKl6WuhUCjuc8MwTLWv5OXlSV5eXqLTANLCrf4XYQ0gM3ENQJBxDUDQcQ1A0CW0iZ40aZIsXbpUVqxYIZ06dWqqFxUViciXv4nq0KFDU/3AgQOm30qli19SuHO7dVXr0Z27PLvN3osmqvUyedf2GH/pv1CtXy+DE5lSUhbuMT9wi4iMVRIWrbQsN/fl8cMNIiOsvyeT+1/E/8mK6UjhtuLkvio/p6WjsRsuP89Ua/7a+47GSJdMXwN+5vf1acVpCnfVXeZrRvED1n/y7CeZ0v9lk/Rre9TBGG6kcHvNScq3G+vr4hbHkx4j06ViDRgX9BUjNz5JPvTO32x/v5Y2L6KncFvxMoU7HbdplcJtRVsvv6zQ3+nn3lJn+6t/+PCwqfbG2a0djVE9Wd97FM1OzbXE0Z9zG4YhEydOlEWLFskbb7whpaXxb2lRWloqRUVFsmzZsqZaQ0ODlJeXy+DBqd9kAW6i/xF0rAEEGf2PoGMNACc4eiV6woQJ8swzz8iSJUskPz+/6f8bIpGItGzZUkKhkEyePFmmT58uZWVlUlZWJtOnT5dWrVrJjTfe6MkJAKlC/yPoWAMIMvofQccaAE5wtIl+9NFHRURk2LBhcfUFCxbI2LFjRURk6tSpcvToURk/fnzTm6y//vrrkp+f78qEgXSh/xF0rAEEGf2PoGMNACc42kQbhnHKY0KhkEybNk2mTZuW6JwAX6L/EXSsAQQZ/Y+gYw0AJySczu0XlX/sq9ZL/0V/s/ZzX9LDT9b0D7s2Jzu8DBCzYhUy4sT1Jc7+p2VplR5A4EbAm5MAsf136vPuMNQcPhA1Mis05OWqtY6O90twkdW8reZX8Ww/td5uiTnQK/+51a7cphO7fn2RWn/o2qfU+twyc83L+WUrq/us559+otdve8/L6QTW/Tv1n8N93fTezZQQsUxmtTbGXHCVWv/ZildNNachWn5/rAqfcYZaj336qe0xrM5x78/15xmdptPriQi995GEQs3SPQ3PaMG2IiJHh37iaJwjVw8y1Vot1p/vu/Ecw2mAmJfPa5wEiHkxj4TfJxoAAAAAgKBhEw0AAAAAgE1sogEAAAAAsIlNNAAAAAAANrGJBgAAAADAppBhJ68+hWprayUSicjI1/5JmrVuHve1uksO2h5n/xSLNOZZ3qUkhnt2V+uxrTs8u02/8DKFO9WixnFZLkukpqZGCgoKUn77X62BYTJacjMsmTL6HT3lMPevzhLEvfLLCr1PnaZN+oWTtEm7x9L/qRFu20atxw793bPbnLrjQ7U+s/vZnt2ml7bNNafSioiUTTAn0z6y+2312Eld
LjbVgroG9r7Yx1Rrt6C1emyLl/WUe95hwD7tvnLrfjpj1emm2qeDP7f1vX7u/8jKtqbja4YcUsf5bKz+jhnfWviOWh+x8XNT7X/6nP7Nkz1JkPt/6/zz1Hr3/9+o1sNvfmCqVd+u792Kfqvv3Qau08dee675NeJvva1fcz+7OP6a66T/eSUaAAAAAACb2EQDAAAAAGATm2gAAAAAAGxiEw0AAAAAgE1sogEAAAAAsCk33ROw0qFFrTRvGZ/Kt8XJ93uYwm3FSQr3tkf0VNG2H+i/12izQE8TdGLrf+rJeT3/+f2kx/Z7CnftDReq9YJnV6d4Jva9sGW9FOTH94PfEx7dSOFuHHquWs8pX5f02Pfc/GN9bFmf9NhuqPpXPZmy+EH98cyqH4zB/ZRjE59XOoS7dZFwOC+uFttekabZuM8qhdvLdNdLWzSo9ZlJj5weWgq3iMgrVebU1yuLzSncfjdn0zuSf9I14Eedh3h2e52+tzHpMZz0qWu9HgqZaxZvPLNwz0q1PtaF+9XqfHotGq/WvXxM1pK451ucu5c95TarJG6NVQq3FadJ3Bo3HqczNeG754/0vcRju/W+u2XynaZa8f98oh4bs7hNLYXb6javmjdVPbZYEt8v8ko0AAAAAAA2sYkGAAAAAMAmNtEAAAAAANjEJhoAAAAAAJvYRAMAAAAAYJNv07m3DT0uuUrgol1+Src7/H1zEnfZJD1V1MrlH9Wq9df6Fphq1ueuj52Tn6/WG+vq7E1ORMKnR9R67PMa22NYWVq1Rq1bJYJvXWD+Gff8oX9TuK1c26u/5IaanfrALGOVwp1b3FGtR6v2mWrWCayJzysV1v30EbU+8kFnj1t//u/fm8fwebLnyWI7d0soif5/bd96tX55x/4Jj5kKXv6c/N4DT1qs2x87TA++6pzvKlU9Dd3P2odbSkE48dc68sqL1Hr90OqEx3ST037cNkd/V5Pfj3jSVJvR/Rz1WDdSuHO7dVXrVs+xysTZ8z2vZFIKd5BZrQsn+5rua1qox+44/1jiEzuFWbv0NPTbuuh910pbF+3aqsc63dP9wxu3m2o9Ld7lJBm8Eg0AAAAAgE1sogEAAAAAsIlNNAAAAAAANrGJBgAAAADAJt8GiyXLywCVmzbvVetPn9lJrbf+Y/KhElqAmBWrc69a1EetF1+zMaE5fZ2XAWK9/zJOrZfJB2q95w/NAQROw8ngP1qAmBU3wmO8Fjr/bFPtyuv6q8fmyHq1vv+l3hajm9dAuEepeuSS8j/GfV5b1yjte1kMm0G8DBB7pUp/7LmyeIBnt5lt9KCYlq6MHTtkP0QsfFZPU82I1YtsdmUqSbEbLhm2COPxS4CYW8om6s+lZogeIuaV6M5dno2972eD1XrHh5IPRTo28gK13vrdirjPjcYGkYNJ31zK3LNzvVr/dbf+KZ2H15zsa6wCxLb9Qb9Gld2iX9OcmNL1IkfHa9cAq3N0uqfT9gEv7NWDzy59/5/iPo8dqRe5wd7t8Eo0AAAAAAA2sYkGAAAAAMAmNtEAAAAAANjEJhoAAAAAAJvYRAMAAAAAYFPIMAwj3ZP4utraWolEIjJMRptSKRfuWWk63mkKr54I6m2atxvu3rFBrc/ontpUSqe+9/EBtf5i7/Ypnol9UeO4LJclUlNTIwUF9lPR3fJNayC3Q5Hp+Oh+Zwmslffq6Z8lv0w+/TNThdu2UetOUn6zhZ/7X5Opj+nZZttTFqmvtyaf+ppqmbYGYBbuo7/FQGzjlhTPxFl69LaF+uPW7y9ZoNa9eA6Y7f3vl2vGgSVnqvX2o717a4Dw6RG17sY77GQLJ/3PK9EAAAAAANjEJhoAAAAAAJvYRAMAAAAAYBObaAAAAAAAbMpN9wRO9lXOWVSOi5wUeVZX12g6PmocdzR+rTJGIuOk2uG6mFr3+7yPfhFV636ed1S+nFu6Mve+aQ1IY4PpeKf3Zaz+mFr388/Ea4Zyv4qIxAJ4n/i6/xWZ+piebRqP
Zs/jSqatAZgZsXq1no7H9MMOHqOs1lEqnwNme//75ZoRO6L3qJfzMAye65yKk/73XTr33r17paSkJN3TAKSyslI6deqU8ttlDcAP6H8EHWsAQUb/I8js9L/vNtGNjY2yb98+yc/Pl7q6OikpKZHKysq0xOynSm1tbdafZyado2EYUldXJx07dpScnNT/x0PQ1kAm9UaiMukc6f/UyqTeSEYmnSdrILUyqTcSlUnnSP+nVib1RjIy5Tyd9L/v/pw7JyenaecfCoVERKSgoMDXd7hbgnCemXKOkYj+XnqpENQ1wDn6B/2fekE4R5HMOU/WQOpxjv5B/6deEM5RJDPO027/EywGAAAAAIBNbKIBAAAAALDJ15vovLw8ue+++yQvLy/dU/FUEM4zCOfohSDcb5wjrAThfgvCOYoE5zzdFoT7jXOElSDcb0E4R5HsPE/fBYsBAAAAAOBXvn4lGgAAAAAAP2ETDQAAAACATWyiAQAAAACwiU00AAAAAAA2sYkGAAAAAMAmX2+i582bJ6WlpdKiRQsZOHCgvPXWW+meUsJWrFghV111lXTs2FFCoZC89NJLcV83DEOmTZsmHTt2lJYtW8qwYcNk48aN6ZlsgmbMmCHnn3++5OfnS/v27WXMmDGyZcuWuGOy4TxTJZv6XyT71wD9775sWgPZ3v8irAG3ZVP/i2T/GqD/3UX/Z15vBG0N+HYT/fzzz8vkyZPlnnvukXXr1skll1wiI0aMkD179qR7agk5fPiw9OvXT+bMmaN+febMmTJr1iyZM2eOrFmzRoqKiuSyyy6Turq6FM80ceXl5TJhwgRZvXq1LFu2TKLRqAwfPlwOHz7cdEw2nGcqZFv/i2T/GqD/3ZVtayDb+1+ENeCmbOt/kexfA/S/e+j/zOyNwK0Bw6cuuOAC47bbbournXnmmcZdd92Vphm5R0SMxYsXN33e2NhoFBUVGQ888EBT7dixY0YkEjEee+yxNMzQHQcOHDBExCgvLzcMI3vP0wvZ3P+GEYw1QP8nJ5vXQBD63zBYA8nI5v43jGCsAfo/cfR/dvRGtq8BX74S3dDQIGvXrpXhw4fH1YcPHy6rVq1K06y8U1FRIdXV1XHnm5eXJ0OHDs3o862pqRERkTZt2ohI9p6n24LW/yLZ2Rv0f+KCtgaytTdYA4kJWv+LZGdv0P+Jof+/lA29ke1rwJeb6IMHD0osFpPCwsK4emFhoVRXV6dpVt756pyy6XwNw5ApU6bIkCFDpG/fviKSnefphaD1v0j29Qb9n5ygrYFs7A3WQOKC1v8i2dcb9H/i6P8TMvmcg7AGctM9gW8SCoXiPjcMw1TLJtl0vhMnTpQNGzbIypUrTV/LpvP0UhDvp2w5Z/rfHUG7r7LpfFkDyQvi/ZQt50z/Jy+I91M2nXMQ1oAvX4lu166dhMNh028lDhw4YPrtRTYoKioSEcma8500aZIsXbpU3nzzTenUqVNTPdvO0ytB63+R7OoN+j95QVsD2dYbrIHkBK3/RbKrN+j/5ND/J2TqOQdlDfhyE928eXMZOHCgLFu2LK6+bNkyGTx4cJpm5Z3S0lIpKiqKO9+GhgYpLy/PqPM1DEMmTpwoixYtkjfeeENKS0vjvp4t5+m1oPW/SHb0Bv3vnqCtgWzpDdaAO4LW/yLZ0Rv0vzvo/y9lYm8Ebg2kLMLMoeeee85o1qyZMX/+fGPTpk3G5MmTjdatWxu7du1K99QSUldXZ6xbt85Yt26dISLGrFmzjHXr1hm7d+82DMMwHnjgASMSiRiLFi0yPvzwQ+OGG24wOnToYNTW1qZ55vb95Cc/MSKRiLF8+XJj//79TR9HjhxpOiYbzjMVsq3/DSP71wD9765sWwPZ3v+GwRpwU7b1v2Fk/xqg/91D/2dmbwRtDfh2E20YhjF37lyjS5cuRvPmzY0BAwY0RaRnojfffNMQEdPHrbfeahjGl7Hv9913n1FUVGTk5eUZl156qfHhhx+md9IOaecnIsaC
BQuajsmG80yVbOp/w8j+NUD/uy+b1kC2979hsAbclk39bxjZvwbof3fR/5nXG0FbAyHDMAx3XtMGAAAAACC7+fJ/ogEAAAAA8CM20QAAAAAA2MQmGgAAAAAAm9hEAwAAAABgE5toAAAAAABsYhMNAAAAAIBNbKIBAAAAALCJTTQAAAAAADaxiQYAAAAAwCY20QAAAAAA2MQmGgAAAAAAm/4XugQPfkmrhMkAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 1200x400 with 5 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "perm = torch.randperm(784)\n",
    "\n",
    "batch = next(iter(train_loader))\n",
    "x = batch[0][:10]\n",
    "y = batch[1][:10]\n",
    "\n",
    "fig, axs = plt.subplots(1, 5, figsize=(12, 4))\n",
    "\n",
    "for i in range(5):\n",
    "    axs[i].imshow(x[i].squeeze().numpy())\n",
    "    \n",
    "fig, axs = plt.subplots(1, 5, figsize=(12, 4))\n",
    "x = x.view(-1, 28*28)\n",
    "x = x[:, perm]\n",
    "x = x.view(-1, 1, 28, 28)\n",
    "\n",
    "for i in range(5):\n",
    "    axs[i].imshow(x[i].squeeze().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parameters=6.422K\n",
      "epoch=0, step=0: train loss=2.2717\n",
      "epoch=0, step=100: train loss=1.4825\n",
      "epoch=0, step=200: train loss=0.9727\n",
      "epoch=0, step=300: train loss=0.8701\n",
      "epoch=0, step=400: train loss=0.7045\n",
      "epoch=0, step=500: train loss=0.5553\n",
      "epoch=0, step=600: train loss=0.4980\n",
      "epoch=0, step=700: train loss=0.6311\n",
      "epoch=0, step=800: train loss=0.2293\n",
      "epoch=0, step=900: train loss=0.4658\n",
      "test loss=0.4080, accuracy=0.8724\n"
     ]
    }
   ],
   "source": [
    "# ConvNet on shuffled pixels\n",
    "n_kernels = 6\n",
    "convnet = ConvNet(input_size, n_kernels, output_size)\n",
    "convnet.to(device)\n",
    "print(f\"Parameters={sum(p.numel() for p in convnet.parameters())/1e3}K\")\n",
    "\n",
    "train(convnet, perm=perm)\n",
    "test(convnet, perm=perm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parameters=6.442K\n",
      "epoch=0, step=0: train loss=2.3154\n",
      "epoch=0, step=100: train loss=1.5152\n",
      "epoch=0, step=200: train loss=0.6278\n",
      "epoch=0, step=300: train loss=0.4071\n",
      "epoch=0, step=400: train loss=0.4127\n",
      "epoch=0, step=500: train loss=0.3370\n",
      "epoch=0, step=600: train loss=0.2718\n",
      "epoch=0, step=700: train loss=0.4750\n",
      "epoch=0, step=800: train loss=0.2036\n",
      "epoch=0, step=900: train loss=0.3137\n",
      "test loss=0.3322, accuracy=0.9019\n"
     ]
    }
   ],
   "source": [
    "# MLP on shuffled pixels\n",
    "input_size = 28*28  \n",
    "output_size = 10   \n",
    "\n",
    "n_hidden = 8\n",
    "mlp = MLP(input_size, n_hidden, output_size)\n",
    "mlp.to(device)\n",
    "print(f\"Parameters={sum(p.numel() for p in mlp.parameters())/1e3}K\")\n",
    "\n",
    "train(mlp, perm=perm)\n",
    "test(mlp, perm=perm)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The convolutional network's performance drops when we permute the pixels, but the MLP's performance stays the same.\n",
    "* ConvNet makes the assumption that pixels lie on a grid and are stationary/local.\n",
    "* It loses performance when this assumption is wrong.\n",
    "* The fully-connected network does not make this assumption.\n",
    "* It does less well when it is true, since it doesn't take advantage of this prior knowledge.\n",
    "* But it doesn't suffer when the assumption is wrong."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:torch-cpu]",
   "language": "python",
   "name": "conda-env-torch-cpu-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
